{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using PyTorch-based models in GluonTS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook illustrates how one can implement a time series model using PyTorch, and use it together with the rest of the GluonTS ecosystem for data loading, feature processing, and model evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional, Callable\n",
    "from itertools import islice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib.dates as mdates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gluonts.dataset.field_names import FieldName\n",
    "from gluonts.dataset.loader import TrainDataLoader\n",
    "from gluonts.dataset.repository.datasets import get_dataset\n",
    "from gluonts.evaluation import Evaluator\n",
    "from gluonts.evaluation.backtest import make_evaluation_predictions\n",
    "from gluonts.torch.batchify import batchify \n",
    "from gluonts.torch.support.util import copy_parameters\n",
    "from gluonts.torch.model.predictor import PyTorchPredictor\n",
    "from gluonts.transform import Chain, AddObservedValuesIndicator, InstanceSplitter, ExpectedNumInstanceSampler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example we will use the \"electricity\" dataset, which can be loaded as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the electricity dataset; its train, test and metadata attributes are used in the cells below.\n",
    "dataset = get_dataset(\"electricity\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is what the first nine time series from the training portion of the dataset look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label the x-axis ticks with the year only.\n",
    "date_formater = mdates.DateFormatter('%Y')\n",
    "\n",
    "fig = plt.figure(figsize=(12,8))\n",
    "# Plot the first nine training entries on a 3x3 grid of subplots.\n",
    "for idx, entry in enumerate(islice(dataset.train, 9)):\n",
    "    ax = plt.subplot(3, 3, idx+1)\n",
    "    # Rebuild the time index from the entry's start timestamp, its frequency and the target length.\n",
    "    t = pd.date_range(start=entry[\"start\"], periods=len(entry[\"target\"]), freq=entry[\"start\"].freq)\n",
    "    plt.plot(t, entry[\"target\"])\n",
    "    # Three ticks at year-start frequency, counted from 2011-12-31.\n",
    "    plt.xticks(pd.date_range(start=pd.to_datetime(\"2011-12-31\"), periods=3, freq=\"AS\"))\n",
    "    ax.xaxis.set_major_formatter(date_formater)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Probabilistic feed-forward network using PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use a pretty simple model, based on a feed-forward network whose output layer produces the parameters of a Student's t-distribution at each time step in the prediction range. We will define two networks based on this idea:\n",
    "* The `TrainingFeedForwardNetwork` computes the loss associated with given observations, i.e. the negative log-likelihood of the observations according to the output distribution; this will be used during training.\n",
    "* The `SamplingFeedForwardNetwork` will be used at inference time: this uses the output distribution to draw a sample of a given size, as a way to encode the predicted distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_abs_scaling(context, min_scale=1e-5):\n",
    "    return context.abs().mean(1).clamp(min_scale, None).unsqueeze(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_scaling(context):\n",
    "    return torch.ones(context.shape[0], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrainingFeedForwardNetwork(nn.Module):\n",
    "    \"\"\"Feed-forward network producing a Student's t-distribution per forecast step.\n",
    "\n",
    "    The context window is divided by the output of ``scaling`` and passed through\n",
    "    an MLP whose output is reshaped to\n",
    "    ``(batch, prediction_length, hidden_dimensions[-1])``. Three linear heads map\n",
    "    each step's features to the degrees of freedom, the location and the scale of\n",
    "    the distribution. Calling the module returns the negative log-likelihood of\n",
    "    the target, averaged over dimension 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    prediction_length\n",
    "        Number of time steps to predict; must be positive.\n",
    "    context_length\n",
    "        Number of past time steps fed to the network; must be positive.\n",
    "    hidden_dimensions\n",
    "        Layer sizes; must be non-empty. All entries but the last are the widths\n",
    "        of the hidden layers; the last one is the per-step feature size that\n",
    "        feeds the distribution heads.\n",
    "    batch_norm\n",
    "        Whether to append a ``BatchNorm1d`` layer after each hidden activation.\n",
    "    scaling\n",
    "        Callable mapping the context tensor to a scale tensor that broadcasts\n",
    "        against it.\n",
    "    \"\"\"\n",
    "\n",
    "    distr_type = torch.distributions.StudentT\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        prediction_length: int,\n",
    "        context_length: int,\n",
    "        hidden_dimensions: List[int],\n",
    "        batch_norm: bool=False,\n",
    "        scaling: Callable=mean_abs_scaling,\n",
    "    ) -> None:\n",
    "        super().__init__()\n",
    "\n",
    "        assert prediction_length > 0\n",
    "        assert context_length > 0\n",
    "        assert len(hidden_dimensions) > 0\n",
    "\n",
    "        self.prediction_length = prediction_length\n",
    "        self.context_length = context_length\n",
    "        self.hidden_dimensions = hidden_dimensions\n",
    "        self.batch_norm = batch_norm\n",
    "        self.scaling = scaling\n",
    "\n",
    "        # Input width followed by the hidden widths; the last entry of\n",
    "        # hidden_dimensions is reserved for the per-step feature size.\n",
    "        dimensions = [context_length] + hidden_dimensions[:-1]\n",
    "\n",
    "        modules = []\n",
    "        for in_size, out_size in zip(dimensions[:-1], dimensions[1:]):\n",
    "            modules += [self.__make_lin(in_size, out_size), nn.ReLU()]\n",
    "            if batch_norm:\n",
    "                # Normalize the out_size features produced by the layer above.\n",
    "                modules.append(nn.BatchNorm1d(out_size))\n",
    "        # The last layer emits hidden_dimensions[-1] features for every forecast step.\n",
    "        modules.append(self.__make_lin(dimensions[-1], prediction_length * hidden_dimensions[-1]))\n",
    "        self.nn = nn.Sequential(*modules)\n",
    "\n",
    "        # Softplus keeps the degrees of freedom and the scale positive.\n",
    "        self.df_proj = nn.Sequential(self.__make_lin(hidden_dimensions[-1], 1), nn.Softplus())\n",
    "        self.loc_proj = self.__make_lin(hidden_dimensions[-1], 1)\n",
    "        self.scale_proj = nn.Sequential(self.__make_lin(hidden_dimensions[-1], 1), nn.Softplus())\n",
    "\n",
    "    @staticmethod\n",
    "    def __make_lin(dim_in, dim_out):\n",
    "        \"\"\"Build a linear layer with weights uniform in [-0.07, 0.07] and zero bias.\"\"\"\n",
    "        lin = nn.Linear(dim_in, dim_out)\n",
    "        torch.nn.init.uniform_(lin.weight, -0.07, 0.07)\n",
    "        torch.nn.init.zeros_(lin.bias)\n",
    "        return lin\n",
    "\n",
    "    def distr_and_scale(self, context):\n",
    "        \"\"\"Return the output distribution over the scaled target, and the scale used.\"\"\"\n",
    "        scale = self.scaling(context)\n",
    "        scaled_context = context / scale\n",
    "        nn_out = self.nn(scaled_context)\n",
    "        nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])\n",
    "\n",
    "        distr_args = (\n",
    "            # The 2.0 offset keeps the degrees of freedom above 2.\n",
    "            2.0 + self.df_proj(nn_out_reshaped).squeeze(dim=-1),\n",
    "            self.loc_proj(nn_out_reshaped).squeeze(dim=-1),\n",
    "            self.scale_proj(nn_out_reshaped).squeeze(dim=-1),\n",
    "        )\n",
    "        # Resolve the distribution class through self, not through a notebook global.\n",
    "        distr = self.distr_type(*distr_args)\n",
    "\n",
    "        return distr, scale\n",
    "\n",
    "    def forward(self, context, target):\n",
    "        \"\"\"Return the negative log-likelihood of ``target``, averaged over dimension 1.\"\"\"\n",
    "        assert context.shape[-1] == self.context_length\n",
    "        assert target.shape[-1] == self.prediction_length\n",
    "\n",
    "        distr, scale = self.distr_and_scale(context)\n",
    "        # The distribution models target / scale; adding log(scale) converts the\n",
    "        # density back to the original units of the target.\n",
    "        loss = (-distr.log_prob(target / scale) + torch.log(scale)).mean(dim=1)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SamplingFeedForwardNetwork(TrainingFeedForwardNetwork):\n",
    "    \"\"\"Inference-time variant that draws samples instead of computing a loss.\"\"\"\n",
    "\n",
    "    def __init__(self, *args, num_samples: int = 1000, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "    def forward(self, context):\n",
    "        \"\"\"Draw ``num_samples`` samples from the output distribution.\n",
    "\n",
    "        The samples are multiplied by the scale and returned with the sample\n",
    "        axis moved from position 0 to position 1.\n",
    "        \"\"\"\n",
    "        assert context.shape[-1] == self.context_length\n",
    "\n",
    "        distr, scale = self.distr_and_scale(context)\n",
    "        sample_shape = (self.num_samples, )\n",
    "        scaled_draws = distr.sample(sample_shape)\n",
    "        draws = scaled_draws * scale\n",
    "\n",
    "        # Swap the leading sample axis with the axis that follows it.\n",
    "        return draws.permute(1, 0, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now instantiate the training network, and explore its set of parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of past observations fed to the network (presumably two weeks of hourly data; confirm against dataset.metadata.freq).\n",
    "context_length = 2 * 7 * 24\n",
    "# The forecast horizon is taken from the dataset metadata.\n",
    "prediction_length = dataset.metadata.prediction_length\n",
    "# Passed as hidden_dimensions to the network constructed below.\n",
    "hidden_dimensions = [96, 48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size and number of batches per epoch handed to the training data loader.\n",
    "batch_size = 32\n",
    "num_batches_per_epoch = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the training network with the sizes chosen above, without batch normalization.\n",
    "net = TrainingFeedForwardNetwork(\n",
    "    prediction_length=prediction_length,\n",
    "    context_length=context_length,\n",
    "    hidden_dimensions=hidden_dimensions,\n",
    "    batch_norm=False,\n",
    "    scaling=mean_abs_scaling,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total number of scalar parameters in the network.\n",
    "sum(p.numel() for p in net.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the shape of every parameter tensor of the network.\n",
    "for p in net.parameters():\n",
    "    print(p.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the training data loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now set up the data loader which will yield batches of data to train on. Starting from the original dataset, the data loader is configured to apply the following transformation, which does essentially two things:\n",
    "* Replaces `nan`s in the target field with a dummy value (zero), and adds a field indicating which values were actually observed vs imputed this way.\n",
    "* Slices out training instances of a fixed length randomly from the given dataset; these will be stacked into batches by the data loader itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformation = Chain([\n",
    "    # Adds a field indicating which target values were actually observed.\n",
    "    AddObservedValuesIndicator(\n",
    "        target_field=FieldName.TARGET,\n",
    "        output_field=FieldName.OBSERVED_VALUES,\n",
    "    ),\n",
    "    # Slices out instances with context_length past values and prediction_length future values.\n",
    "    InstanceSplitter(\n",
    "        target_field=FieldName.TARGET,\n",
    "        is_pad_field=FieldName.IS_PAD,\n",
    "        start_field=FieldName.START,\n",
    "        forecast_start_field=FieldName.FORECAST_START,\n",
    "        train_sampler=ExpectedNumInstanceSampler(num_instances=1),\n",
    "        past_length=context_length,\n",
    "        future_length=prediction_length,\n",
    "        # The observed-values indicator is split along with the target.\n",
    "        time_series_fields=[FieldName.OBSERVED_VALUES],\n",
    "    ),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loader over dataset.train: applies the transformation and stacks instances with batchify.\n",
    "data_loader = TrainDataLoader(\n",
    "    dataset.train,\n",
    "    batch_size=batch_size,\n",
    "    stack_fn=batchify,\n",
    "    transform=transformation,\n",
    "    num_batches_per_epoch=num_batches_per_epoch\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now train the model using any of the available optimizers from PyTorch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(net.parameters())\n",
    "\n",
    "for epoch_no in range(10):\n",
    "    sum_epoch_loss = 0.0\n",
    "    # Stays at zero if the loader yields nothing, so the average below is still defined.\n",
    "    batch_no = 0\n",
    "    for batch_no, batch in enumerate(data_loader, start=1):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        context = batch[\"past_target\"]\n",
    "        target = batch[\"future_target\"]\n",
    "        # The network returns a vector of losses; reduce it to a scalar for backpropagation.\n",
    "        loss_vec = net(context, target)\n",
    "        loss = loss_vec.mean()\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "        # item() reads the scalar value directly, whatever device the tensor is on.\n",
    "        sum_epoch_loss += loss.item()\n",
    "\n",
    "    # Average over the batches actually consumed in this epoch.\n",
    "    print(f\"{epoch_no}: {sum_epoch_loss / max(batch_no, 1)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create predictor out of the trained model, and test it"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have a trained model, whose parameters can be copied over to a `SamplingFeedForwardNetwork` object: we will wrap this into a `PyTorchPredictor` that can be used for inference tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mirror every constructor setting of the trained network, including its scaling function,\n",
    "# so that inference preprocesses the context exactly as training did.\n",
    "pred_net = SamplingFeedForwardNetwork(\n",
    "    prediction_length=net.prediction_length,\n",
    "    context_length=net.context_length,\n",
    "    hidden_dimensions=net.hidden_dimensions,\n",
    "    batch_norm=net.batch_norm,\n",
    "    scaling=net.scaling,\n",
    ")\n",
    "copy_parameters(net, pred_net)\n",
    "\n",
    "predictor_pytorch = PyTorchPredictor(\n",
    "    prediction_length=prediction_length,\n",
    "    freq=dataset.metadata.freq,\n",
    "    input_names=[\"past_target\"],\n",
    "    prediction_net=pred_net,\n",
    "    # Reuse the batch size configured for training instead of repeating the literal.\n",
    "    batch_size=batch_size,\n",
    "    input_transform=transformation,\n",
    "    device=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, we can do backtesting on the test dataset: in what follows, `make_evaluation_predictions` will slice out the trailing `prediction_length` observations from the test time series, and use the given predictor to obtain forecasts for the same time range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "forecast_it, ts_it = make_evaluation_predictions(\n",
    "    dataset=dataset.test,\n",
    "    predictor=predictor_pytorch,\n",
    "    num_samples=1000,\n",
    ")\n",
    "\n",
    "# Materialize both iterators; the lists are reused for plotting and for evaluation below.\n",
    "forecasts_pytorch = list(forecast_it)\n",
    "tss_pytorch = list(ts_it)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have the forecasts, we can plot them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 15))\n",
    "# Tick labels show the abbreviated month name and the day of the month.\n",
    "date_formater = mdates.DateFormatter('%b, %d')\n",
    "plt.rcParams.update({'font.size': 15})\n",
    "\n",
    "# Plot the first nine forecasts, each in its own cell of a 3x3 grid.\n",
    "for idx, (forecast, ts) in islice(enumerate(zip(forecasts_pytorch, tss_pytorch)), 9):\n",
    "    ax =plt.subplot(3, 3, idx+1)\n",
    "    \n",
    "    # Show only the last 5 * prediction_length observations of the series.\n",
    "    plt.plot(ts[-5 * prediction_length:], label=\"target\")\n",
    "    forecast.plot()\n",
    "    plt.xticks(rotation=60)\n",
    "    ax.xaxis.set_major_formatter(date_formater)\n",
    "    \n",
    "plt.gcf().tight_layout()\n",
    "# Called once after the loop, so the legend is attached to the current (last) subplot.\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also compute evaluation metrics that summarize the performance of the model on our test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluator configured with the quantile levels 0.1, 0.5 and 0.9.\n",
    "evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the first return value is kept; the second one is discarded.\n",
    "metrics_pytorch, _ = evaluator(iter(tss_pytorch), iter(forecasts_pytorch), num_series=len(dataset.test))\n",
    "# Show the metrics as a table with a single FeedForward column (presumably metrics_pytorch is a dict of scalars; confirm).\n",
    "pd.DataFrame.from_records(metrics_pytorch, index=[\"FeedForward\"]).transpose()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
